"""
Phase 9: Eurozone Deep Dive
============================
1. Within-EMU CA regressions (levels, deviations, distance)
2. Project Z₁ divergence to 2060
3. Regime strain index
4. Yield spread analysis
5. Pre-crisis vs post-crisis
6. Forward EMU counterfactual (P(peg) projections)
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent

# PanelGLS is shared with the sibling "multilateral" project; put its source
# directory first on the import path before importing it.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Inputs for this phase, and the directory that receives every output table
# (markdown and CSV alike).
DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"
OUT_CSV = OUT_TABLES  # CSVs go alongside .md tables
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# Processed data of the multilateral project (source of full_panel.csv).
MULTILATERAL_DATA = ROOT_DIR / "multilateral" / "followup" / "data" / "processed"

# Euro adoption year per member (ISO3). The eleven 1999 founders come first,
# then later entrants in accession order; this insertion order is the order
# in which get_ez_df stacks countries.
# NOTE(review): Croatia (HRV) adopted the euro on 1 January 2023 and is not
# listed -- confirm the omission is intentional.
EUROZONE_JOIN = {
    **dict.fromkeys(
        ('AUT', 'BEL', 'FIN', 'FRA', 'DEU', 'IRL', 'ITA', 'LUX', 'NLD',
         'PRT', 'ESP'),
        1999,
    ),
    'GRC': 2001,
    'SVN': 2007,
    'CYP': 2008,
    'MLT': 2008,
    'SVK': 2009,
    'EST': 2011,
    'LVA': 2014,
    'LTU': 2015,
}
EUROZONE_ISO3 = set(EUROZONE_JOIN)

# OECD member states (ISO3); the P(peg) counterfactual trains on the subset
# outside the euro area.
OECD = set("""
    AUS AUT BEL CAN CHL COL CRI CZE DNK EST
    FIN FRA DEU GRC HUN ISL IRL ISR ITA JPN
    KOR LVA LTU LUX MEX NLD NZL NOR POL PRT
    SVK SVN ESP SWE CHE TUR GBR USA
""".split())


def stars(p):
    """Return the conventional significance marker for p-value ``p``.

    ``'***'`` below 0.01, ``'**'`` below 0.05, ``'*'`` below 0.1 and ``''``
    otherwise -- including NaN, which fails every comparison.
    """
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def fmt(val, se, p):
    """Format one coefficient and its standard error for a table cell pair.

    Returns ``(coef_text, se_text)``: the coefficient to four decimals with
    its significance stars appended, and the standard error to four decimals
    wrapped in parentheses.
    """
    coef_text = "{:.4f}{}".format(val, stars(p))
    se_text = "({:.4f})".format(se)
    return coef_text, se_text


def run_panel_gls(df, y_var, x_vars, label):
    """Fit ``PanelGLS`` of ``y_var`` on ``x_vars`` and return a summary dict.

    Only rows complete in ``y_var``, every regressor, ``iso3`` and ``year``
    are used. Returns ``None`` (after printing why) when fewer than 50 such
    rows exist or when the fit raises. Otherwise prints a coefficient summary
    and returns a dict holding ``model`` (the label), ``n_obs``,
    ``n_countries``, ``r_squared``, ``rho``, ``<var>_coef`` / ``<var>_se`` /
    ``<var>_p`` for each regressor, and -- for later projections -- the
    fitted estimator under ``_gls`` and the regressor list under ``_x_vars``.
    """
    sample = df[[y_var] + x_vars + ['iso3', 'year']].dropna()
    if len(sample) < 50:
        print(f"  {label}: insufficient obs ({len(sample)}), skipping")
        return None

    estimator = PanelGLS()
    outcome = sample[y_var].values
    design = sample[x_vars].values
    try:
        estimator.fit(outcome, design,
                      sample['iso3'].values, sample['year'].values)
    except Exception as err:
        print(f"  {label}: GLS failed ({err}), skipping")
        return None

    summary = {
        'model': label,
        'n_obs': estimator.n_obs,
        'n_countries': estimator.n_countries,
        'r_squared': estimator.r_squared,
        'rho': estimator.rho,
    }
    print(f"\n  {label} (N={estimator.n_obs}, R²={estimator.r_squared:.4f})")
    for idx, name in enumerate(x_vars):
        coef = estimator.beta[idx]
        se = estimator.se[idx]
        pval = estimator.pvalues[idx]
        print(f"    {name:30s} {coef:8.4f} ({se:.4f}) {stars(pval)}")
        summary[f'{name}_coef'] = coef
        summary[f'{name}_se'] = se
        summary[f'{name}_p'] = pval

    # Keep the fitted estimator and its regressor list for reuse downstream.
    summary['_gls'] = estimator
    summary['_x_vars'] = x_vars
    return summary


def write_table(results, filename, title, note=None):
    """Write regression results side by side as a markdown table.

    Each result dict (as returned by ``run_panel_gls``) becomes one column.
    Rows list every regressor seen in any model, in first-seen order: the
    coefficient with significance stars on one line and the standard error
    in parentheses beneath it. N, R² and the country count follow.

    Parameters
    ----------
    results : list of dict
        Model results; nothing is written when the list is empty.
    filename : str
        Name of the file created in ``OUT_TABLES``.
    title : str
        Text of the level-1 heading.
    note : str, optional
        Footer markdown. Defaults to a generic panel-GLS note plus the
        significance legend.
    """
    if not results:
        return

    suffix = '_coef'
    lines = [f"# {title}\n"]

    # Union of regressor names across models, in first-seen order. Slice off
    # only the trailing suffix: str.replace would also strip an inner
    # '_coef' (a variable called 'gini_coef' would collapse to 'gini').
    all_vars = []
    for r in results:
        for k in r:
            if k.endswith(suffix):
                vname = k[:-len(suffix)]
                if vname not in all_vars:
                    all_vars.append(vname)

    model_labels = [r['model'] for r in results]
    separator = "|:---|" + "|".join(["---:" for _ in results]) + "|"
    lines.append("| Variable | " + " | ".join(model_labels) + " |")
    lines.append(separator)

    for var in all_vars:
        coef_row = f"| {var} |"
        se_row = "| |"
        for r in results:
            if f'{var}_coef' in r:
                c, s = fmt(r[f'{var}_coef'], r[f'{var}_se'], r[f'{var}_p'])
                coef_row += f" {c} |"
                se_row += f" {s} |"
            else:
                # This model does not include the regressor: leave cells blank.
                coef_row += " |"
                se_row += " |"
        lines.append(coef_row)
        lines.append(se_row)

    # Divider before the fit statistics. GFM renders it as an ordinary body
    # row showing the literal dashes; kept as a visual rule.
    lines.append(separator)
    n_row = "| N |"
    r2_row = "| R² |"
    nc_row = "| Countries |"
    for r in results:
        n_row += f" {r['n_obs']} |"
        r2_row += f" {r['r_squared']:.4f} |"
        nc_row += f" {r['n_countries']} |"
    lines.append(n_row)
    lines.append(r2_row)
    lines.append(nc_row)

    if note:
        lines.append(f"\n{note}")
    else:
        lines.append("\n*Panel GLS with country and year fixed effects. "
                     "Standard errors in parentheses.*")
        lines.append("*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")

    path = OUT_TABLES / filename
    # Explicit UTF-8: labels contain non-ASCII characters such as '→', which
    # the platform default encoding (e.g. cp1252 on Windows) cannot encode.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")


def get_ez_df(df):
    """Return the euro-area rows of ``df`` from each member's accession year on.

    Rows are stacked country by country in ``EUROZONE_JOIN`` order and the
    result is re-indexed from 0.
    """
    pieces = [
        df[(df['iso3'] == code) & (df['year'] >= first_year)]
        for code, first_year in EUROZONE_JOIN.items()
    ]
    return pd.concat(pieces, ignore_index=True)


# ── 1. Within-EMU CA Regressions ───────────────────────────────────

def within_emu_ca(df):
    """Regress the current account on demographic factors inside the EMU.

    Builds the post-accession EMU sample and adds, for each of Z_1..Z_3:
    the deviation from that year's EMU cross-sectional mean (``<z>_dev``)
    and the distance from Germany in the same year (``<z>_dist_deu``).
    ``ca_dev`` is CA/GDP minus its EMU-year mean. When ``old_dep`` /
    ``youth_dep`` exist, their EMU-year deviations are added too. Up to five
    PanelGLS specifications are fitted and written to
    ``phase9_within_emu_ca.md``.

    Returns
    -------
    (ez_dev, results)
        The augmented EMU frame and the list of non-skipped result dicts.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("1. WITHIN-EMU CA REGRESSIONS")
    print(banner)

    ez_df = get_ez_df(df)
    print(f"  EMU sample: {len(ez_df)} obs, {ez_df['iso3'].nunique()} countries")
    print(f"  Year range: {ez_df['year'].min()}-{ez_df['year'].max()}")

    z_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth']

    # EMU cross-sectional mean of every Z and of CA/GDP, one row per year.
    year_means = (ez_df.groupby('year')[z_vars + ['ca_gdp']].mean()
                  .add_suffix('_emu_mean'))
    ez_dev = ez_df.merge(year_means, on='year', how='left')
    for z in z_vars:
        ez_dev[f'{z}_dev'] = ez_dev[z] - ez_dev[f'{z}_emu_mean']
    ez_dev['ca_dev'] = ez_dev['ca_gdp'] - ez_dev['ca_gdp_emu_mean']

    # Germany's Z values per year are the reference point for distances.
    deu_rows = ez_df.loc[ez_df['iso3'] == 'DEU', ['year'] + z_vars]
    deu_z = deu_rows.set_index('year').add_suffix('_deu')
    ez_dev = ez_dev.merge(deu_z, on='year', how='left')
    for z in z_vars:
        ez_dev[f'{z}_dist_deu'] = ez_dev[z] - ez_dev[f'{z}_deu']

    # Dependency-ratio deviations, only for the columns the panel provides.
    for av in ('old_dep', 'youth_dep'):
        if av not in ez_df.columns:
            continue
        av_mean = ez_df.groupby('year')[av].mean().rename(f'{av}_emu_mean')
        ez_dev = ez_dev.merge(av_mean, on='year', how='left')
        ez_dev[f'{av}_dev'] = ez_dev[av] - ez_dev[f'{av}_emu_mean']

    z_dev_vars = [f'{z}_dev' for z in z_vars]
    dist_vars = [v for v in (f'{z}_dist_deu' for z in z_vars)
                 if v in ez_dev.columns]
    age_dev_vars = [f'{av}_dev' for av in ('old_dep', 'youth_dep')
                    if f'{av}_dev' in ez_dev.columns]

    # (dependent variable, regressors, label), fitted in this order.
    specs = [
        ('ca_gdp', z_vars + controls, 'Z levels'),
        ('ca_dev', z_dev_vars, 'Z dev → CA dev'),
        ('ca_dev', z_dev_vars + controls, 'Z dev + ctrl'),
    ]
    if dist_vars:
        specs.append(('ca_gdp', dist_vars + controls, 'Dist from DEU'))
    if age_dev_vars:
        specs.append(('ca_dev', age_dev_vars, 'Age dev'))

    results = []
    for y_var, x_vars, label in specs:
        res = run_panel_gls(ez_dev, y_var, x_vars, label)
        if res:
            results.append(res)

    write_table(results, "phase9_within_emu_ca.md",
                "Within-EMU CA Regressions: Levels, Deviations, Distance",
                note=("*Panel GLS with country and year fixed effects. "
                      "EMU members post-accession only. Deviations computed "
                      "from EMU-year cross-sectional means.*\n"
                      "*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*"))

    return ez_dev, results


# ── 2. Project Z₁ Divergence to 2060 ───────────────────────────────

def project_z_divergence(ez_dev):
    """Tabulate projected Z_1 for EMU members at decade points up to 2060.

    Reads ``full_panel.csv`` from the multilateral project, keeps the members
    in ``EUROZONE_ISO3`` and records Z_1..Z_3 for each decade year
    2000-2060 that is present. Prints the cross-sectional dispersion of Z_1,
    the country x decade levels and the largest 2040 deviations from the EMU
    mean. Writes ``phase9_emu_z_projections.csv`` and
    ``phase9_emu_projection_summary.md``.

    Parameters
    ----------
    ez_dev : DataFrame
        Accepted for call-site symmetry with the other sections; not used.

    Returns
    -------
    DataFrame or None
        One row per (iso3, decade) with Z_1..Z_3, ``Z_1_emu_mean`` and
        ``Z_1_dev``. ``None`` when no decade rows are found.
    """
    print("\n" + "=" * 60)
    print("2. PROJECT Z₁ DIVERGENCE TO 2060")
    print("=" * 60)

    fp = pd.read_csv(MULTILATERAL_DATA / "full_panel.csv")
    ez_fp = fp[fp['iso3'].isin(EUROZONE_ISO3)].copy()

    decades = [2000, 2010, 2020, 2030, 2040, 2050, 2060]

    # One row per (country, decade) present in the panel. If a country-year
    # appears more than once, its first occurrence is used.
    proj_rows = []
    for iso3 in sorted(EUROZONE_ISO3):
        cdata = ez_fp[ez_fp['iso3'] == iso3]
        for yr in decades:
            row_data = cdata[cdata['year'] == yr]
            if len(row_data) == 0:
                continue
            proj_rows.append({
                'iso3': iso3,
                'year': yr,
                'Z_1': row_data['Z_1'].values[0],
                'Z_2': row_data['Z_2'].values[0],
                'Z_3': row_data['Z_3'].values[0],
            })

    proj_df = pd.DataFrame(proj_rows)

    if len(proj_df) == 0:
        print("  No projection data available")
        return None

    # Compute EMU mean and deviations per decade
    emu_decade_mean = proj_df.groupby('year')['Z_1'].mean().rename('Z_1_emu_mean')
    proj_df = proj_df.merge(emu_decade_mean, on='year', how='left')
    proj_df['Z_1_dev'] = proj_df['Z_1'] - proj_df['Z_1_emu_mean']

    # Cross-sectional dispersion per decade
    disp = proj_df.groupby('year')['Z_1'].agg(['std', 'mean', 'count'])
    disp.columns = ['Z_1_std', 'Z_1_mean', 'n_countries']
    print("\n  EMU Z₁ Dispersion by Decade:")
    print(disp.to_string(float_format='%.4f'))

    # Pivot: countries × decades
    pivot = proj_df.pivot(index='iso3', columns='year', values='Z_1')
    pivot.columns = [f'Z1_{int(c)}' for c in pivot.columns]
    print("\n  EMU Z₁ Levels by Country and Decade:")
    print(pivot.to_string(float_format='%.3f'))

    # Deviation pivot
    dev_pivot = proj_df.pivot(index='iso3', columns='year', values='Z_1_dev')
    dev_pivot.columns = [f'dev_{int(c)}' for c in dev_pivot.columns]

    # Which countries diverge most by 2040?
    if 'dev_2040' in dev_pivot.columns:
        print("\n  Largest Z₁ Deviations from EMU Mean by 2040:")
        sorted_2040 = dev_pivot['dev_2040'].abs().sort_values(ascending=False)
        for iso3 in sorted_2040.head(10).index:
            val = dev_pivot.loc[iso3, 'dev_2040']
            direction = "older than EMU" if val > 0 else "younger than EMU"
            print(f"    {iso3}: {val:+.3f} ({direction})")

    # Save CSV
    combined = pivot.merge(dev_pivot, left_index=True, right_index=True)
    combined.to_csv(OUT_CSV / "phase9_emu_z_projections.csv")
    print(f"\n  Saved: {OUT_CSV / 'phase9_emu_z_projections.csv'}")

    # Write summary markdown
    lines = ["# EMU Z₁ Projections to 2060\n"]
    lines.append("## Cross-Sectional Dispersion\n")
    lines.append("| Decade | Z₁ Mean | Z₁ Std Dev | N Countries |")
    lines.append("|---:|---:|---:|---:|")
    for yr, row in disp.iterrows():
        lines.append(f"| {int(yr)} | {row['Z_1_mean']:.4f} | "
                     f"{row['Z_1_std']:.4f} | {int(row['n_countries'])} |")

    # Only decades that actually occur in the data become columns. Header,
    # separator and every row iterate the same list so cells stay aligned
    # even when an early decade (e.g. 2000) is missing from the panel.
    shown_decades = [d for d in decades if f'Z1_{int(d)}' in pivot.columns]
    lines.append("\n## Country Z₁ Levels\n")
    lines.append("| Country | "
                 + " | ".join(f"{int(d)}" for d in shown_decades) + " |")
    lines.append("|:---|" + "|".join("---:" for _ in shown_decades) + "|")
    for iso3 in sorted(pivot.index):
        row_str = f"| {iso3} |"
        for d in shown_decades:
            val = pivot.loc[iso3, f'Z1_{int(d)}']
            row_str += f" {val:.3f} |" if pd.notna(val) else " |"
        lines.append(row_str)

    lines.append("\n*Z₁ values from UN WPP population projections. "
                 "Higher Z₁ indicates older demographic structure.*")

    path = OUT_TABLES / "phase9_emu_projection_summary.md"
    # Explicit UTF-8: the text contains 'Z₁', which locale encodings such as
    # cp1252 cannot encode.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"  Saved: {path}")

    return proj_df


# ── 3. Regime Strain Index ──────────────────────────────────────────

def regime_strain(ez_dev, ca_results, proj_df):
    """Compute a demographic regime-strain index from deviation coefficients.

    Uses the Z_dev coefficients of the within-EMU ``'Z dev + ctrl'`` model
    (falling back to ``'Z dev → CA dev'``). The index therefore shows the
    cross-sectional spread -- who runs a surplus or deficit relative to the
    EMU mean -- rather than the all-deficit levels driven by the common
    aging trend.

    For every (country, decade) in ``proj_df`` the predicted CA/GDP
    deviation is ``sum_k coef_k * Z_k_dev``, and strain is its absolute
    value. Prints the 2040 ranking and writes ``phase9_regime_strain.csv``
    and ``phase9_regime_strain.md``.

    Parameters
    ----------
    ez_dev : DataFrame
        Not used; kept for call-site symmetry.
    ca_results : list of dict
        Result dicts from the within-EMU CA regressions.
    proj_df : DataFrame or None
        Decade projections with ``iso3``, ``year``, Z_1..Z_3 and ``Z_1_dev``.

    Returns
    -------
    DataFrame or None
        One row per (iso3, decade) with the prediction and its direction, or
        ``None`` when a required input is missing.
    """
    print("\n" + "=" * 60)
    print("3. REGIME STRAIN INDEX (DEVIATION-BASED)")
    print("=" * 60)

    # Prefer the specification with controls; fall back to the bare one.
    z_dev_result = None
    for wanted in ('Z dev + ctrl', 'Z dev → CA dev'):
        z_dev_result = next(
            (r for r in ca_results if r and r['model'] == wanted), None)
        if z_dev_result is not None:
            break

    if z_dev_result is None:
        print("  No within-EMU Z deviation result available, skipping")
        return None

    z_vars = ['Z_1', 'Z_2', 'Z_3']
    z_dev_vars = [f'{z}_dev' for z in z_vars]
    z_coefs = {z: z_dev_result[f'{z}_coef'] for z in z_dev_vars}

    print(f"  Using model: {z_dev_result['model']}")
    for z in z_dev_vars:
        print(f"    {z} coef = {z_coefs[z]:.4f}")

    if proj_df is None or len(proj_df) == 0:
        print("  No projection data, skipping")
        return None

    # proj_df already carries Z_1_dev; add Z_2/Z_3 deviations from the
    # per-decade EMU mean. merge() returns a new frame, so the caller's
    # proj_df is left untouched.
    for z in ['Z_2', 'Z_3']:
        if f'{z}_dev' not in proj_df.columns:
            emu_mean = proj_df.groupby('year')[z].mean().rename(f'{z}_emu_mean')
            proj_df = proj_df.merge(emu_mean, on='year', how='left')
            proj_df[f'{z}_dev'] = proj_df[z] - proj_df[f'{z}_emu_mean']

    # Predicted CA deviation for each country-decade.
    strain_rows = []
    decades = [2000, 2010, 2020, 2030, 2040, 2050, 2060]

    for iso3 in sorted(EUROZONE_ISO3):
        for yr in decades:
            cyr = proj_df[(proj_df['iso3'] == iso3) & (proj_df['year'] == yr)]
            if len(cyr) == 0:
                continue

            predicted_ca_dev = sum(
                z_coefs[f'{z}_dev'] * cyr[f'{z}_dev'].values[0]
                for z in z_vars
            )
            strain_rows.append({
                'iso3': iso3,
                'year': yr,
                'Z_1_dev': cyr['Z_1_dev'].values[0],
                'predicted_ca_dev': predicted_ca_dev,
                'strain': abs(predicted_ca_dev),
                'direction': 'surplus' if predicted_ca_dev > 0 else 'deficit',
            })

    strain_df = pd.DataFrame(strain_rows)

    if len(strain_df) == 0:
        print("  No strain data computed")
        return None

    # Rank by 2040 predicted deviation (most negative first).
    strain_2040 = strain_df[strain_df['year'] == 2040].sort_values(
        'predicted_ca_dev', ascending=True)
    print("\n  Regime Strain Ranking (2040, deviation from EMU mean):")
    for _, row in strain_2040.iterrows():
        print(f"    {row['iso3']}: CA_dev={row['predicted_ca_dev']:+.2f} "
              f"(Z₁_dev={row['Z_1_dev']:+.3f}, {row['direction']})")

    # Save CSV
    strain_df.to_csv(OUT_CSV / "phase9_regime_strain.csv", index=False)
    print(f"\n  Saved: {OUT_CSV / 'phase9_regime_strain.csv'}")

    # Pivot table for markdown
    pivot = strain_df.pivot(index='iso3', columns='year',
                            values='predicted_ca_dev')
    pivot.columns = [f'{int(c)}' for c in pivot.columns]

    lines = ["# EMU Regime Strain Index (Deviation-Based)\n"]
    lines.append("Predicted CA/GDP *deviation from EMU mean* using Z deviation "
                 "coefficients.\n")
    lines.append("Positive = demographically-driven surplus relative to EMU mean. "
                 "Negative = demographically-driven deficit.\n")
    lines.append("Strain = |predicted CA deviation|. Larger absolute values "
                 "indicate greater pressure from demographic mismatch.\n")

    lines.append("| Country | " + " | ".join(pivot.columns) + " |")
    lines.append("|:---|" + "|".join(["---:" for _ in pivot.columns]) + "|")

    # Rows follow the 2040 ranking. Countries without a 2040 projection (or
    # all countries, if 2040 is absent) are appended alphabetically instead
    # of being silently dropped from the table.
    row_order = [iso3 for iso3 in strain_2040['iso3'] if iso3 in pivot.index]
    listed = set(row_order)
    row_order += sorted(iso3 for iso3 in pivot.index if iso3 not in listed)

    for iso3 in row_order:
        row_str = f"| {iso3} |"
        for col in pivot.columns:
            val = pivot.loc[iso3, col]
            if pd.notna(val):
                row_str += f" {val:+.2f} |"
            else:
                row_str += " |"
        lines.append(row_str)

    lines.append(f"\n*Coefficients from within-EMU Z deviation regression "
                 f"(N={z_dev_result['n_obs']}). Z deviations = country Z minus "
                 f"EMU cross-sectional mean per year. "
                 f"Z projections from UN WPP.*")

    path = OUT_TABLES / "phase9_regime_strain.md"
    # Explicit UTF-8 so the output does not depend on the platform locale.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"  Saved: {path}")

    return strain_df


# ── 4. Yield Spread Analysis ───────────────────────────────────────

def yield_spreads(df):
    """Regress euro-area 10-year spreads over the Bund on demographics.

    The spread is each member's ``govt_bond_10y`` minus Germany's in the
    same year, over the post-accession EMU sample with Germany dropped.
    Returns early when fewer than 50 spreads are observed. Otherwise fits
    four PanelGLS specifications and writes ``phase9_yield_spreads.md``:
    Z levels, Z deviations from the EMU-year mean, the dependency-ratio
    split (when available), and Z levels without Greece. Returns None.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("4. YIELD SPREAD ANALYSIS")
    print(banner)

    ez_df = get_ez_df(df)

    # German 10-year yield per year, taken from the full panel.
    bund = (df.loc[df['iso3'] == 'DEU', ['year', 'govt_bond_10y']]
            .rename(columns={'govt_bond_10y': 'bund_10y'}))
    spreads = ez_df.merge(bund, on='year', how='left')
    spreads['spread_vs_deu'] = spreads['govt_bond_10y'] - spreads['bund_10y']
    # Germany is the benchmark, so it is not part of the spread sample.
    spreads = spreads[spreads['iso3'] != 'DEU'].copy()

    observed = spreads['spread_vs_deu'].notna()
    n_spread = observed.sum()
    print(f"  Spread observations: {n_spread}")
    print(f"  Countries with spreads: {sorted(spreads.loc[observed, 'iso3'].unique())}")

    if n_spread < 50:
        print("  Insufficient spread data, skipping")
        return

    z_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth']

    # NOTE(review): these EMU-year means are taken after DEU was dropped,
    # whereas within_emu_ca includes DEU in its means -- confirm intended.
    year_means = spreads.groupby('year')[z_vars].mean().add_suffix('_emu_mean')
    spreads = spreads.merge(year_means, on='year', how='left')
    for z in z_vars:
        spreads[f'{z}_dev'] = spreads[z] - spreads[f'{z}_emu_mean']

    z_dev_vars = [f'{z}_dev' for z in z_vars]
    avail_age = [v for v in ('old_dep', 'youth_dep') if v in spreads.columns]

    # (sample, regressors, label), fitted in this order.
    specs = [
        (spreads, z_vars + controls, 'Z levels'),
        (spreads, z_dev_vars + controls, 'Z deviations'),
    ]
    if avail_age:
        specs.append((spreads, avail_age + controls, 'Age decomp'))
    # Robustness check without Greece (treated as an outlier).
    specs.append((spreads[spreads['iso3'] != 'GRC'].copy(),
                  z_vars + controls, 'Excl GRC'))

    results = []
    for frame, x_vars, label in specs:
        res = run_panel_gls(frame, 'spread_vs_deu', x_vars, label)
        if res:
            results.append(res)

    write_table(results, "phase9_yield_spreads.md",
                "Eurozone Yield Spreads vs Bund: Demographic Determinants",
                note=("*Panel GLS with country and year fixed effects. "
                      "DV = 10-year government bond yield minus German Bund. "
                      "EMU members post-accession, excluding DEU.*\n"
                      "*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*"))


# ── 5. Pre-Crisis vs Post-Crisis ───────────────────────────────────

def pre_post_crisis(df):
    """Compare demographic current-account effects before and after the crisis.

    Fits CA/GDP on Z_1..Z_3 plus controls over three samples: the whole
    post-accession EMU panel, 1999-2007, and 2010-2024 (so 2008-2009 is left
    out of both sub-periods). When the columns exist, the dependency-ratio
    split is also fitted for each sub-period. Writes
    ``phase9_pre_post_crisis.md`` and returns None.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("5. PRE-CRISIS vs POST-CRISIS")
    print(banner)

    ez_df = get_ez_df(df)

    z_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth']

    def _span(first, last):
        # Copy of the EMU rows whose year lies in [first, last].
        years = ez_df['year']
        return ez_df[(years >= first) & (years <= last)].copy()

    pre = _span(1999, 2007)    # convergence era
    post = _span(2010, 2024)   # after the sovereign-debt crisis

    for name, frame in (("Pre-crisis (1999-2007)", pre),
                        ("Post-crisis (2010-2024)", post)):
        print(f"  {name}: {len(frame)} obs, "
              f"{frame['iso3'].nunique()} countries")

    # (sample, regressors, label), fitted in this order.
    specs = [
        (ez_df, z_vars + controls, 'Full EMU'),
        (pre, z_vars + controls, 'Pre-crisis'),
        (post, z_vars + controls, 'Post-crisis'),
    ]
    avail_age = [v for v in ('old_dep', 'youth_dep') if v in ez_df.columns]
    if avail_age:
        specs += [
            (pre, avail_age + controls, 'Pre: Age'),
            (post, avail_age + controls, 'Post: Age'),
        ]

    results = []
    for frame, x_vars, label in specs:
        res = run_panel_gls(frame, 'ca_gdp', x_vars, label)
        if res:
            results.append(res)

    write_table(results, "phase9_pre_post_crisis.md",
                "Pre-Crisis vs Post-Crisis: EMU Demographic CA Effects",
                note=("*Panel GLS with country and year fixed effects. "
                      "Pre-crisis: 1999-2007 (convergence era). "
                      "Post-crisis: 2010-2024 (post-sovereign debt crisis).*\n"
                      "*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*"))


# ── 6. Forward EMU Counterfactual ───────────────────────────────────

def _fit_logit(X, y):
    """Fit an unpenalised logit of ``y`` on ``X`` by maximum likelihood (BFGS).

    ``X`` must carry a leading column of ones. The remaining columns are
    standardised before optimisation for numerical stability, and the
    estimates are mapped back to the original scale.

    Returns
    -------
    (beta, converged)
        Coefficients on the original scale (intercept first) and the
        optimiser's ``success`` flag.
    """
    from scipy.optimize import minimize

    x_means = X[:, 1:].mean(axis=0)
    x_stds = X[:, 1:].std(axis=0)
    x_stds[x_stds == 0] = 1  # constant regressor: avoid division by zero
    X_std = X.copy()
    X_std[:, 1:] = (X[:, 1:] - x_means) / x_stds

    def _probs(beta):
        # Clip the linear index so exp() cannot overflow.
        z = np.clip(X_std @ beta, -30, 30)
        return 1 / (1 + np.exp(-z))

    def neg_log_likelihood(beta):
        p = np.clip(_probs(beta), 1e-12, 1 - 1e-12)
        return -np.sum(y * np.log(p) + (1 - y) * np.log(1 - p))

    def gradient(beta):
        return -X_std.T @ (y - _probs(beta))

    opt = minimize(neg_log_likelihood, np.zeros(X.shape[1]), jac=gradient,
                   method='BFGS', options={'maxiter': 1000, 'gtol': 1e-6})
    beta_std = opt.x
    beta = np.empty_like(beta_std)
    beta[1:] = beta_std[1:] / x_stds
    beta[0] = beta_std[0] - np.sum(beta_std[1:] * x_means / x_stds)
    return beta, opt.success


def _peg_float_counts(yr_proj):
    """Return ``(n_peg, n_float, n_total)`` for one year of P(peg) rows."""
    p = yr_proj['p_peg']
    return (p >= 0.5).sum(), (p < 0.5).sum(), len(yr_proj)


def forward_counterfactual(df):
    """Project P(peg) for euro-area members at 2020/2030/2040/2050.

    A logit of ``is_peg`` on Z_1..Z_3 and four macro controls is trained on
    OECD members outside the euro area from 1999 on. For each euro-area
    member, the controls are frozen at their last fully observed values in
    ``df`` and combined with Z_1..Z_3 from ``full_panel.csv`` for each
    projection year. Prints the projections and writes
    ``phase9_forward_counterfactual.md``. Returns None.
    """
    print("\n" + "=" * 60)
    print("6. FORWARD EMU COUNTERFACTUAL")
    print("=" * 60)

    fp = pd.read_csv(MULTILATERAL_DATA / "full_panel.csv")

    z_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'kaopen']
    x_vars = z_vars + controls

    # Training sample: OECD non-eurozone, post-1999
    train_df = df[(df['iso3'].isin(OECD)) &
                  (~df['iso3'].isin(EUROZONE_ISO3)) &
                  (df['year'] >= 1999)].copy()

    if 'is_peg' not in train_df.columns:
        # Peg dummy from regime_3cat (code 1 = peg). A missing regime stays
        # missing so the dropna() below excludes it; a bare `== 1` comparison
        # would turn NaN into 0.0 and silently label the row a float.
        regime = train_df['regime_3cat']
        train_df['is_peg'] = (regime == 1).astype(float).where(regime.notna())

    train_sub = train_df[['is_peg'] + x_vars + ['iso3']].dropna()
    print(f"  Training sample (OECD non-EZ): {len(train_sub)} obs, "
          f"{train_sub['iso3'].nunique()} countries")

    y_train = train_sub['is_peg'].values.astype(float)
    X_train = np.column_stack([np.ones(len(train_sub)),
                               train_sub[x_vars].values.astype(float)])

    # Require at least 5 pegs and 5 floats before attempting the fit.
    if y_train.sum() < 5 or (1 - y_train).sum() < 5:
        print("  Insufficient outcome variation, skipping")
        return

    try:
        beta, converged = _fit_logit(X_train, y_train)
    except Exception as e:
        print(f"  Logit estimation failed: {e}")
        return

    print(f"  Logit converged: {converged}")

    # Last fully observed control values per EMU country
    last_controls = (df[df['iso3'].isin(EUROZONE_ISO3)]
                     .dropna(subset=controls)
                     .sort_values('year')
                     .groupby('iso3')[controls]
                     .last())

    missing = EUROZONE_ISO3 - set(last_controls.index)
    if missing:
        print(f"  Missing control data for: {sorted(missing)}")

    # Project P(peg) for each EMU member at future decades
    proj_years = [2020, 2030, 2040, 2050]
    proj_rows = []

    for iso3 in sorted(EUROZONE_ISO3):
        if iso3 not in last_controls.index:
            print(f"    {iso3}: skipped (no control data)")
            continue
        ctrl_vals = last_controls.loc[iso3, controls].to_numpy(dtype=float)

        cdata = fp[fp['iso3'] == iso3]

        for yr in proj_years:
            yr_data = cdata[cdata['year'] == yr]
            if len(yr_data) == 0:
                continue
            z_vals = yr_data[z_vars].iloc[0].to_numpy(dtype=float)
            if np.isnan(z_vals).any():
                # A NaN demographic input would give a NaN probability that
                # is counted as neither peg nor float below.
                continue

            # Feature vector in training order: [1, Z_1, Z_2, Z_3, controls]
            x_vec = np.concatenate(([1.0], z_vals, ctrl_vals))
            z_score = np.clip(x_vec @ beta, -30, 30)
            p_peg = 1 / (1 + np.exp(-z_score))

            proj_rows.append({
                'iso3': iso3,
                'year': yr,
                'p_peg': p_peg,
                'predicted_regime': 'Peg' if p_peg >= 0.5 else 'Float',
            })

    proj_df = pd.DataFrame(proj_rows)

    if len(proj_df) == 0:
        print("  No projections computed")
        return

    # Pivot and display
    pivot = proj_df.pivot(index='iso3', columns='year', values='p_peg')
    pivot.columns = [f'P(peg)_{int(c)}' for c in pivot.columns]

    print("\n  Forward P(peg) Projections for EMU Members:")
    print(pivot.to_string(float_format='%.3f'))

    # Count who would float by decade
    for yr in proj_years:
        n_peg, n_float, n_total = _peg_float_counts(proj_df[proj_df['year'] == yr])
        print(f"  {yr}: {n_peg}/{n_total} peg, {n_float}/{n_total} float")

    # Trend: does aging make euro MORE or LESS natural?
    if 'P(peg)_2020' in pivot.columns and 'P(peg)_2050' in pivot.columns:
        delta = pivot['P(peg)_2050'] - pivot['P(peg)_2020']
        n_more_natural = (delta > 0).sum()
        n_less_natural = (delta < 0).sum()
        print(f"\n  2020→2050 trend: {n_more_natural} countries become MORE "
              f"natural peggers, {n_less_natural} become LESS natural")

    # Write markdown
    lines = ["# Forward EMU Counterfactual: P(peg) Projections\n"]
    lines.append("Logit trained on OECD non-eurozone (post-1999): "
                 "is_peg = f(Z, controls)\n")
    lines.append("Controls held at last observed values; demographics from "
                 "UN WPP projections.\n")

    lines.append("| Country | " + " | ".join([f"P(peg) {yr}" for yr in proj_years]) + " |")
    lines.append("|:---|" + "|".join(["---:" for _ in proj_years]) + "|")
    for iso3 in sorted(pivot.index):
        row_str = f"| {iso3} |"
        for yr in proj_years:
            col = f'P(peg)_{yr}'
            if col in pivot.columns and pd.notna(pivot.loc[iso3, col]):
                row_str += f" {pivot.loc[iso3, col]:.3f} |"
            else:
                row_str += " |"
        lines.append(row_str)

    for yr in proj_years:
        n_peg, n_float, n_total = _peg_float_counts(proj_df[proj_df['year'] == yr])
        lines.append(f"\n**{yr}**: {n_peg}/{n_total} predicted peg, "
                     f"{n_float}/{n_total} predicted float")

    # Note excluded countries, separating missing controls from missing
    # projected demographics.
    missing_from_proj = EUROZONE_ISO3 - set(proj_df['iso3'].unique())
    no_controls = sorted(missing_from_proj & missing)
    no_demographics = sorted(missing_from_proj - missing)
    if no_controls:
        lines.append(f"\n*Note: {', '.join(no_controls)} excluded "
                     f"due to missing control variables (fiscal balance, NFA, "
                     f"growth, or KAOPEN).*")
    if no_demographics:
        lines.append(f"\n*Note: {', '.join(no_demographics)} excluded "
                     f"due to missing projected demographics (Z₁, Z₂, Z₃) "
                     f"for 2020-2050.*")

    lines.append("\n*Higher P(peg) means demographics more consistent with "
                 "choosing a fixed exchange rate. Values below 0.5 suggest "
                 "the country's demographics favor floating.*")
    lines.append("\n*Note: Z₁ values exhibit a discrete jump around 2035-2040 "
                 "reflecting the baby-boom cohort entering old age simultaneously "
                 "across countries. This produces non-monotonic P(peg) paths "
                 "for some members.*")

    path = OUT_TABLES / "phase9_forward_counterfactual.md"
    # Explicit UTF-8: the text contains 'Z₁', which locale encodings such as
    # cp1252 cannot encode.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")


# ── Main ─────────────────────────────────────────────────────────────

def main():
    """Run every Phase 9 section in order on ``trilemma_panel.csv``.

    Section 3 (regime strain) consumes the within-EMU regression results of
    section 1 and the decade projections of section 2. Sections 4-6 each
    work directly from the loaded panel.
    """
    rule = "=" * 70
    print(rule)
    print("PHASE 9: EUROZONE DEEP DIVE")
    print(rule)

    panel = pd.read_csv(DATA / "trilemma_panel.csv")
    print(f"Panel: {len(panel)} obs, {panel['iso3'].nunique()} countries")

    ez_dev, ca_results = within_emu_ca(panel)               # 1
    proj_df = project_z_divergence(ez_dev)                  # 2
    regime_strain(ez_dev, ca_results, proj_df)              # 3

    # 4-6: yield spreads, pre/post crisis, forward counterfactual
    for section in (yield_spreads, pre_post_crisis, forward_counterfactual):
        section(panel)

    print("\n" + rule)
    print("PHASE 9 COMPLETE")
    print(rule)


if __name__ == '__main__':
    main()
